import cmd
from miasm2.core.utils import hexdump
import miasm2.jitter.csts as csts
from miasm2.jitter.jitload import ExceptionHandle


class DebugBreakpoint:

    """Debug Breakpoint parent class.

    Holds no state and defines no behaviour; it only gives the breakpoint
    classes of this module a common base type.
    """


class DebugBreakpointSoft(DebugBreakpoint):

    """Software breakpoint: a breakpoint attached to one address."""

    def __init__(self, addr):
        """Remember where the breakpoint is located.

        addr: breakpoint address (formatted with %x by __str__)
        """
        self.addr = addr

    def __str__(self):
        """Return the description "Soft BP @0x<addr>", addr as 8 hex digits."""
        template = "Soft BP @0x%08x"
        return template % self.addr


class DebugBreakpointMemory(DebugBreakpoint):

    """Memory breakpoint: an address range watched for read/write access."""

    # Access flag (taken from miasm2.jitter.csts) -> one-letter label
    type2str = {csts.BREAKPOINT_READ: "R",
                csts.BREAKPOINT_WRITE: "W"}

    def __init__(self, addr, size, access_type):
        """Store the watched range and the accesses that trigger it.

        addr: start address of the watched area
        size: size of the watched area
        access_type: combination of the flags used as keys of type2str
                     (see get_access_type)
        """
        self.addr = addr
        self.size = size
        self.access_type = access_type

    def __str__(self):
        """Describe the breakpoint: address, size and "R"/"W" letters."""
        # One letter per flag of type2str that is present in access_type
        letters = [letter for flag, letter in self.type2str.items()
                   if (flag & self.access_type) != 0]
        return "Memory BP @0x%08x, Size 0x%08x, Type %s" % (
            self.addr, self.size, "".join(letters))

    @classmethod
    def get_access_type(cls, read=False, write=False):
        """Build the access_type value matching the requested accesses.

        read / write: a flag is included only when the argument is exactly
        True (identity test), any other value leaves it out.
        Returns the sum of the selected type2str keys (0 when none).
        """
        requested = {"R": read is True, "W": write is True}
        return sum(flag for flag, letter in cls.type2str.items()
                   if requested.get(letter, False))


class Debugguer(object):

    """Debugguer linked with a Jitter instance.

    Keeps the lists of software breakpoints, memory breakpoints and watched
    memory areas, and forwards run / step / register / memory requests to
    the jitter given at construction time.
    """

    def __init__(self, myjit):
        "myjit : jitter instance"
        self.myjit = myjit
        self.bp_list = []      # DebugBreakpointSoft list
        self.hw_bp_list = []   # DebugBreakpointMemory list
        self.mem_watched = []  # (addr, size) areas dumped by on_step()

    def init_run(self, addr):
        "Forward to the jitter's init_run(addr)"
        self.myjit.init_run(addr)

    def add_breakpoint(self, addr):
        "Add bp @addr"
        bp = DebugBreakpointSoft(addr)
        # The jitter callback simply hands the breakpoint object back; it is
        # stored on the breakpoint so remove_breakpoint() can unregister it.
        func = lambda x: bp
        bp.func = func
        self.bp_list.append(bp)
        self.myjit.add_breakpoint(addr, func)

    def init_memory_breakpoint(self):
        "Set exception handler on EXCEPT_BREAKPOINT_INTERN"
        # NOTE(review): this statement only reads the attribute and registers
        # nothing, so the method currently has no effect -- confirm the
        # intended handler registration against the jitter API.
        self.myjit.exception_handler

    def add_memory_breakpoint(self, addr, size, read=False, write=False):
        "add mem bp @[addr, addr + size], on read/write/both"
        access_type = DebugBreakpointMemory.get_access_type(read=read,
                                                            write=write)
        dbm = DebugBreakpointMemory(addr, size, access_type)
        self.hw_bp_list.append(dbm)
        self.myjit.vm.vm_add_memory_breakpoint(addr, size, access_type)

    def remove_breakpoint(self, dbs):
        "remove the DebugBreakpointSoft instance"
        self.bp_list.remove(dbs)
        self.myjit.remove_breakpoints_by_callback(dbs.func)

    def remove_breakpoint_by_addr(self, addr):
        "remove breakpoints @ addr"
        # get_breakpoint_by_addr() returns a new list, so bp_list can be
        # modified safely while looping.
        for bp in self.get_breakpoint_by_addr(addr):
            self.remove_breakpoint(bp)

    def remove_memory_breakpoint(self, dbm):
        "remove the DebugBreakpointMemory instance"
        self.hw_bp_list.remove(dbm)
        self.myjit.vm.vm_remove_memory_breakpoint(dbm.addr, dbm.access_type)

    def remove_memory_breakpoint_by_addr_access(self, addr, read=False,
                                                write=False):
        "remove memory breakpoints @ addr with exactly this access type"
        access_type = DebugBreakpointMemory.get_access_type(read=read,
                                                            write=write)
        # Collect the matches first: remove_memory_breakpoint() mutates
        # hw_bp_list, and removing while iterating over the same list would
        # skip the element following each removed one.
        matching = [bp for bp in self.hw_bp_list
                    if bp.addr == addr and bp.access_type == access_type]
        for bp in matching:
            self.remove_memory_breakpoint(bp)

    def get_breakpoint_by_addr(self, addr):
        "Return a new list of the software breakpoints set @ addr"
        return [dbgsoft for dbgsoft in self.bp_list if dbgsoft.addr == addr]

    def get_breakpoints(self):
        "Return the software breakpoint list (the internal list itself)"
        return self.bp_list

    def active_trace(self, mn=None, regs=None, newbloc=None):
        """Set the jit logging options; an argument left to None keeps the
        corresponding option (log_mn / log_regs / log_newbloc) unchanged."""
        if mn is not None:
            self.myjit.jit.log_mn = mn
        if regs is not None:
            self.myjit.jit.log_regs = regs
        if newbloc is not None:
            self.myjit.jit.log_newbloc = newbloc

    def handle_exception(self, res):
        """Report why the jitter stopped.

        res: value returned by myjit.continue_run(); None is ignored.
        Raises NotImplementedError for an unsupported ExceptionHandle or for
        any other type of result.
        """
        if res is None:
            return

        if isinstance(res, DebugBreakpointSoft):
            print("Breakpoint reached @0x%08x" % res.addr)
        elif isinstance(res, ExceptionHandle):
            if res == ExceptionHandle.memoryBreakpoint():
                print("Memory breakpoint reached!")

                # Remove flag: mask it out instead of toggling it, so the
                # flag can never be switched on if it was not raised.
                except_flag = self.myjit.vm.vm_get_exception()
                self.myjit.vm.vm_set_exception(
                    except_flag & ~res.except_flag)

            else:
                raise NotImplementedError("Unknown Except")
        else:
            raise NotImplementedError("type res")

    def step(self):
        "Step in jit; return the result of continue_run(step=True)"

        self.myjit.jit.set_options(jit_maxline=1)
        try:
            self.myjit.jit.updt_automod_code(self.myjit.vm, self.myjit.pc, 8)
            res = self.myjit.continue_run(step=True)
            self.handle_exception(res)
        finally:
            # Always put jit_maxline back, even when the run or the
            # exception handling raised.
            self.myjit.jit.set_options(jit_maxline=50)

        self.on_step()

        return res

    def run(self):
        "Continue the execution; return the result of continue_run()"
        res = self.myjit.continue_run()
        self.handle_exception(res)
        return res

    def get_mem(self, addr, size=0xF):
        "hexdump @addr, size"

        hexdump(self.myjit.vm.vm_get_mem(addr, size))

    def get_mem_raw(self, addr, size=0xF):
        "return the memory content @addr, size (result of vm_get_mem)"
        return self.myjit.vm.vm_get_mem(addr, size)

    def watch_mem(self, addr, size=0xF):
        "Register the area @addr, size to be dumped after each step"
        self.mem_watched.append((addr, size))

    def on_step(self):
        "Dump every watched memory area (called at the end of step())"
        for addr, size in self.mem_watched:
            print("@0x%08x:" % addr)
            self.get_mem(addr, size)

    def get_reg_value(self, reg_name):
        "Return the attribute reg_name of the jitter cpu"
        return getattr(self.myjit.cpu, reg_name)

    def set_reg_value(self, reg_name, value):
        "Set the attribute reg_name of the jitter cpu to value"

        # Handle PC case: the run is re-initialised at the new address first
        if reg_name == self.myjit.my_ir.pc.name:
            self.init_run(value)

        setattr(self.myjit.cpu, reg_name, value)

    def get_gpreg_all(self):
        "Return general purposes registers"
        return self.myjit.cpu.vm_get_gpreg()


class DebugCmd(cmd.Cmd, object):

    """CommandLineInterpreter for Debugguer instance.

    Every do_<name> method is a shell command and help_<name> its help
    text (cmd.Cmd convention). Malformed user input is reported with a
    warning instead of raising out of the command loop.
    """

    color_g = '\033[92m'
    color_e = '\033[0m'
    color_b = '\033[94m'
    color_r = '\033[91m'

    intro = color_g + "=== Miasm2 Debugging shell ===\nIf you need help, "
    intro += "type 'help' or '?'" + color_e
    prompt = color_b + "$> " + color_e

    # Known tracing modes; values are refreshed by update_display_mode()
    display_mode = {"mn": None,
                    "regs": None,
                    "newbloc": None}

    def __init__(self, dbg):
        "dbg : Debugguer"
        self.dbg = dbg
        super(DebugCmd, self).__init__()

    # Parsing helpers

    @staticmethod
    def _parse_int(text):
        """Convert a number typed by the user to an int.

        Text containing "0x" (any case) is read as hexadecimal, anything
        else as decimal. Raises ValueError on malformed input.
        """
        if "0x" in text.lower():
            return int(text, 16)
        return int(text)

    def _parse_addr_size(self, arg):
        """Parse "<addr> [size]" (size defaults to 0xF).

        Returns (addr, size), or None after printing a warning when a
        number is malformed.
        """
        args = arg.split()
        try:
            addr = self._parse_int(args[0])
            if len(args) >= 2:
                size = self._parse_int(args[1])
            else:
                size = 0xF
        except ValueError:
            self.print_warning("/!\\ Invalid number in '%s'" % arg)
            return None
        return addr, size

    # Debug methods

    def print_breakpoints(self):
        "Print the index and address of every software breakpoint"
        bp_list = self.dbg.bp_list
        if not bp_list:
            print("No breakpoints.")
            return
        for i, b in enumerate(bp_list):
            print("%d\t0x%08x" % (i, b.addr))

    def print_watchmems(self):
        "Print the index, address and size of every memory watcher"
        watch_list = self.dbg.mem_watched
        if not watch_list:
            print("No memory watchpoints.")
            return
        print("Num\tAddress  \tSize")
        for i, (addr, size) in enumerate(watch_list):
            print("%d\t0x%08x\t0x%08x" % (i, addr, size))

    def print_registers(self):
        "Print a name / value table of the general purpose registers"
        regs = self.dbg.get_gpreg_all()

        # Display settings
        title1 = "Registers"
        title2 = "Values"
        max_name_len = max(len(name)
                           for name in list(regs.keys()) + [title1])

        # Print value table
        s = "%s    |    %s" % (title1.ljust(max_name_len), title2)
        print(s)
        print("-" * len(s))
        for name in sorted(regs):
            # Drop the "L" suffix Python 2 appends to long integers
            value = hex(regs[name]).replace("L", "")
            print("%s    |    %s" % (name.ljust(max_name_len), value))

    def add_breakpoints(self, bp_addr):
        """Add a software breakpoint for each address string of bp_addr.

        Addresses already carrying a breakpoint are reported, not added
        twice; malformed addresses are reported and skipped.
        """
        for text in bp_addr:
            try:
                addr = self._parse_int(text)
            except ValueError:
                self.print_warning("/!\\ Invalid address '%s'" % text)
                continue

            for i, dbg_obj in enumerate(self.dbg.bp_list):
                if dbg_obj.addr == addr:
                    print("Breakpoint 0x%08x already set (%d)" % (addr, i))
                    break
            else:
                # No existing breakpoint matched: the new one gets the next
                # index in the list.
                l = len(self.dbg.bp_list)
                self.dbg.add_breakpoint(addr)
                print("Breakpoint 0x%08x successfully added ! (%d)"
                      % (addr, l))

    def update_display_mode(self):
        "Refresh display_mode from the jit logging options"
        self.display_mode = {"mn": self.dbg.myjit.jit.log_mn,
                             "regs": self.dbg.myjit.jit.log_regs,
                             "newbloc": self.dbg.myjit.jit.log_newbloc}

    # Command line methods
    def print_warning(self, s):
        "Print s in red"
        print(self.color_r + s + self.color_e)

    def onecmd(self, line):
        """Expand one-letter shortcuts, then run the command line.

        A shortcut is expanded when it is the whole line or when it is
        followed by a space.
        """
        cmd_translate = {"h": "help",
                         "q": "exit",
                         "e": "exit",
                         "!": "exec",
                         "r": "run",
                         "i": "info",
                         "b": "breakpoint",
                         "s": "step",
                         "d": "dump"}

        head, rest = line[:1], line[1:]
        if head in cmd_translate and (rest == "" or rest.startswith(" ")):
            line = cmd_translate[head] + rest

        return super(DebugCmd, self).onecmd(line)

    def can_exit(self):
        return True

    def do_display(self, arg):
        "display <mode1> <mode2> ... on|off : switch tracing modes"
        args = arg.split()
        if not args:
            self.help_display()
            return

        state = args[-1].lower()
        if state not in ["on", "off"]:
            self.print_warning("/!\\ %s not in 'on' / 'off'" % args[-1])
            return

        # Mode names become keyword arguments of active_trace(): reject
        # unknown ones here rather than letting the call raise TypeError.
        modes = args[:-1]
        unknown = [m for m in modes if m not in self.display_mode]
        if unknown:
            self.print_warning("/!\\ Unknown display mode: %s"
                               % ", ".join(unknown))
            return

        enabled = state == "on"
        self.dbg.active_trace(**dict((m, enabled) for m in modes))
        self.update_display_mode()

    def help_display(self):
        print("Enable/Disable tracing.")
        print("Usage: display <mode1> <mode2> ... on|off")
        print("Available modes are:")
        for k in self.display_mode:
            print("\t%s" % k)
        print("Use 'info display' to get current values")

    def do_watchmem(self, arg):
        "watchmem <addr> [size] : add a memory watcher"
        if not arg.split():
            self.help_watchmem()
            return

        parsed = self._parse_addr_size(arg)
        if parsed is None:
            return
        addr, size = parsed
        self.dbg.watch_mem(addr, size)

    def help_watchmem(self):
        print("Add a memory watcher.")
        print("Usage: watchmem <addr> [size]")
        print("Use 'info watchmem' to get current memory watchers")

    def do_info(self, arg):
        """info <subcommand> : show registers, display modes, breakpoints
        or memory watchers. Only the first letter of arg is examined."""
        av_info = ["registers",
                   "display",
                   "breakpoints",
                   "watchmem"]

        if arg == "":
            print("'info' must be followed by the name of an info command.")
            print("List of info subcommands:")
            for k in av_info:
                print("\t%s" % k)
            return

        if arg.startswith("b"):
            # Breakpoint
            self.print_breakpoints()
        elif arg.startswith("d"):
            # Display
            self.update_display_mode()
            for k, v in self.display_mode.items():
                print("%s\t\t%s" % (k, v))
        elif arg.startswith("w"):
            # Watchmem
            self.print_watchmems()
        elif arg.startswith("r"):
            # Registers
            self.print_registers()
        else:
            self.print_warning("/!\\ Unknown info subcommand '%s'" % arg)

    def help_info(self):
        print("Generic command for showing things about the program being")
        print("debugged. Use 'info' without arguments to get the list of")
        print("available subcommands.")

    def do_breakpoint(self, arg):
        "breakpoint <addr1> <addr2> ... : add software breakpoints"
        addrs = arg.split()
        if not addrs:
            self.help_breakpoint()
        else:
            self.add_breakpoints(addrs)

    def help_breakpoint(self):
        print("Add breakpoints to argument addresses.")
        print("Example:")
        print("\tbreakpoint 0x11223344")
        print("\tbreakpoint 1122 0xabcd")

    def do_step(self, arg):
        "step [N] : call the debugger's step() N times (default 1)"
        if arg == "":
            nb = 1
        else:
            try:
                nb = int(arg)
            except ValueError:
                self.print_warning("/!\\ Invalid step count '%s'" % arg)
                return

        # A counter loop keeps this independent of xrange / range
        while nb > 0:
            self.dbg.step()
            nb -= 1

    def help_step(self):
        print("Step program until it reaches a different source line.")
        print("Argument N means do this N times (or till program stops")
        print("for another reason).")

    def do_dump(self, arg):
        "dump <addr> [size] : hexdump memory"
        if not arg.split():
            self.help_dump()
            return

        parsed = self._parse_addr_size(arg)
        if parsed is None:
            return
        addr, size = parsed
        self.dbg.get_mem(addr, size)

    def help_dump(self):
        print("Dump <addr> [size]. Dump size bytes at addr.")

    def do_run(self, arg):
        self.dbg.run()

    def help_run(self):
        print("Launch or continue the current program")

    def do_exit(self, s):
        "Returning True stops the command loop"
        return True

    def do_exec(self, l):
        "exec <expr> : evaluate a python expression and print its value"
        # SECURITY NOTE: eval() runs arbitrary code typed at the prompt, with
        # access to this method's scope (self, l). That is the purpose of the
        # command; never feed it untrusted input.
        try:
            print(eval(l))
        except Exception as e:
            # Broad on purpose: a faulty expression must not kill the shell
            print("*** Error: %s" % e)

    def help_exec(self):
        print("Exec a python command.")
        print("You can also use '!' shortcut.")

    def help_exit(self):
        print("Exit the interpreter.")
        print("You can also use the Ctrl-D shortcut.")

    def help_help(self):
        print("Print help")

    def postloop(self):
        print('\nGoodbye !')
        super(DebugCmd, self).postloop()

    # Ctrl-D (EOF) behaves like the exit command
    do_EOF = do_exit
    help_EOF = help_exit
